"""
Run script: functional genome-wide association analysis (FGWAS) pipeline (step 1: MVCM)
Usage: python ./fgwas_step1.py ./data/ ./result/

Author: Chao Huang (chaohuang.stat@gmail.com)
Last update: 2017-12-18
"""

import os
import sys
import time

import numpy as np
from numpy.linalg import inv
from scipy.io import loadmat, savemat

from data_prepro import preprocess
from mvcm import mvcm_1d, mvcm_2d, mvcm_3d

"""
installed all the libraries above
"""


def _load_inputs(input_dir):
    """
    Read the four input files of FGWAS step 1 from the data folder.

    :param
        input_dir (str): path to the data folder (a trailing separator is optional)
    :return
        img_data: 'img_data' variable of img_data.mat, always 3-dimensional
        coord_data: content of coord_data.txt, always 2-dimensional
        design_data: content of design_data.txt
        var_type: content of var_type.txt as a 1-D integer array
    :raises
        ValueError: if the imaging data is neither 2- nor 3-dimensional
    """
    print("+++++++Read the imaging data+++++++")
    img_data = loadmat(os.path.join(input_dir, "img_data.mat"))['img_data']
    if img_data.ndim == 2:
        # a 2-D matrix is treated as a single imaging measure: add a trailing axis of length 1
        img_data = img_data.reshape(img_data.shape[0], img_data.shape[1], 1)
    if img_data.ndim != 3:
        raise ValueError("img_data must be a 2-D or 3-D array, got shape " + str(img_data.shape))
    print("The matrix dimension of image data is " + str(img_data.shape))

    print("+++++++Read the imaging coordinate data+++++++")
    coord_data = np.loadtxt(os.path.join(input_dir, "coord_data.txt"))
    if coord_data.ndim == 1:
        # a single coordinate column is loaded as a vector; make it a column matrix
        coord_data = coord_data.reshape(coord_data.shape[0], 1)
    print("The matrix dimension of coordinate data is " + str(coord_data.shape))

    print("+++++++Read the covariate data+++++++")
    design_data = np.loadtxt(os.path.join(input_dir, "design_data.txt"))
    print("The matrix dimension of covariate data is " + str(design_data.shape))

    # read the covariate type; atleast_1d keeps a file holding a single value
    # iterable (np.loadtxt returns a 0-d array in that case)
    raw_var_type = np.atleast_1d(np.loadtxt(os.path.join(input_dir, "var_type.txt")))
    var_type = np.array([int(v) for v in raw_var_type])

    return img_data, coord_data, design_data, var_type


def _save_temp_results(output_dir, results):
    """
    Save each named array as <output_dir>/temp/<name>.mat for the next step.

    :param
        output_dir (str): path to the output folder (a trailing separator is optional)
        results (list): (name, array) pairs; each array is stored under the key `name`
    """
    temp_dir = os.path.join(output_dir, "temp")
    # savemat does not create missing folders, so make sure the target exists
    if not os.path.isdir(temp_dir):
        os.makedirs(temp_dir)
    for name, value in results:
        savemat(os.path.join(temp_dir, name + ".mat"), mdict={name: value})


def run_step1(input_dir, output_dir):

    """
    Run the commandline script for FGWAS (step 1: MVCM).

    Loads img_data.mat, coord_data.txt, design_data.txt and var_type.txt from
    the data folder, preprocesses them, fits the MVCM with the routine matching
    the coordinate dimension, and writes the intermediate matrices as .mat
    files into the "temp" sub-folder of the output folder.

    :param
        input_dir (str): full path to the data folder
        output_dir (str): full path to the output folder
    """

    # +++++++++++++++++++++++++++++++++++
    print(""" Step 0. load dataset """)
    img_data, coord_data, design_data, var_type = _load_inputs(input_dir)
    n, n_loc, m = img_data.shape

    print("+++++++++Matrix preparing and Data preprocessing++++++++")
    print("+++++++Construct the imaging response, design, coordinate matrix: normalization+++++++")
    y_design, x_design, coord_data = preprocess(coord_data, img_data, design_data, var_type)
    d = coord_data.shape[1]
    p = x_design.shape[1]
    print("The dimension of normalized design matrix is " + str(x_design.shape))

    # +++++++++++++++++++++++++++++++++++
    print(""" Step 1. fit the multivariate varying coefficient model (MVCM) under H0 """)
    start_1 = time.time()

    # calculate the hat matrix X (X'X + ridge * I)^{-1} X'; the small ridge term
    # is added to X'X before it is inverted
    gram = np.dot(x_design.T, x_design) + np.eye(p) * 0.0001
    hat_mat = np.dot(np.dot(x_design, inv(gram)), x_design.T)

    # find the optimal bandwidth & smooth the response data & fit MVCM;
    # any coordinate dimension other than 1 or 2 is handled by the 3-D routine
    if d == 1:
        h_opt, smy_design, resy_design, efit_eta = mvcm_1d(coord_data, y_design, hat_mat)
    elif d == 2:
        h_opt, smy_design, resy_design, efit_eta = mvcm_2d(coord_data, y_design, hat_mat)
    else:
        h_opt, smy_design, resy_design, efit_eta = mvcm_3d(coord_data, y_design, hat_mat)

    # diagnostic only: range of (resy_design - efit_eta) for every imaging measure
    for mii in range(m):
        res_eta = resy_design[:, :, mii] - efit_eta[:, :, mii]
        print("The bound of the residual is "
              "[" + str(np.min(res_eta)) + ", " + str(np.max(res_eta)) + "]")

    inv_sig_eta = np.zeros((n_loc, m, m))
    qr_smy_mat = np.zeros((n, n))
    for lii in range(n_loc):
        # The [:, lii, :] slices are already 2-D (n x m). They must not be
        # squeezed: with m == 1 that would collapse them to vectors and break
        # the matrix products below.
        eta_l = efit_eta[:, lii, :]
        smy_l = smy_design[:, lii, :]
        inv_sig_eta[lii, :, :] = inv(np.dot(eta_l.T, eta_l) / n + np.eye(m) * 0.000001)
        # running average over locations of smy_l * inv_sig_eta_l * smy_l'
        qr_smy_mat += np.dot(np.dot(smy_l, inv_sig_eta[lii, :, :]), smy_l.T) / n_loc

    end_1 = time.time()
    print("Elapsed time in Step 1 is ", end_1 - start_1)

    # save results in temp folder for next step
    _save_temp_results(output_dir, [
        ('x_design', x_design),
        ('coord_data', coord_data),
        ('y_design', y_design),
        ('h_opt', h_opt),
        ('resy_design', resy_design),
        ('smy_design', smy_design),
        ('qr_smy_mat', qr_smy_mat),
        ('efit_eta', efit_eta),
        ('inv_sig_eta', inv_sig_eta),
        ('hat_mat', hat_mat),
    ])


if __name__ == '__main__':
    # Both folder arguments are required; exit with the usage line instead of
    # an IndexError traceback when one of them is missing.
    if len(sys.argv) < 3:
        sys.exit("Usage: python ./fgwas_step1.py ./data/ ./result/")
    input_dir0 = sys.argv[1]
    output_dir0 = sys.argv[2]
    run_step1(input_dir0, output_dir0)
